from distutils.util import strtobool
import os
import random
import argparse

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import gymnasium as gym


# Compute device shared by every module/tensor created in this file:
# CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def boolean_argument(value):
    """Convert a command-line string to a boolean.

    Accepts the same case-insensitive spellings as the deprecated
    ``distutils.util.strtobool`` (distutils was removed in Python 3.12):
    'y', 'yes', 't', 'true', 'on', '1' -> True and
    'n', 'no', 'f', 'false', 'off', '0' -> False.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    normalized = value.lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    # Same message format as strtobool, so argparse error output is unchanged.
    raise ValueError(f"invalid truth value {normalized!r}")


def seed(seed, deterministic_execution=False):
    """Seed Python's ``random``, PyTorch and NumPy global RNGs with ``seed``.

    Args:
        seed: integer seed passed to every generator.
        deterministic_execution: if True, also force cuDNN to use deterministic
            kernels and disable its benchmark autotuning; otherwise only print a
            note that results will be similar but not bit-identical.
    """
    print('Seeding: random, torch, numpy')
    random.seed(seed)
    # torch.random.manual_seed is the same function as torch.manual_seed,
    # so a single call seeds PyTorch.
    torch.manual_seed(seed)
    np.random.seed(seed)

    if deterministic_execution:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # Adjacent literals are joined with no separator, so each piece
        # carries its own trailing space (no embedded source indentation).
        print('Note that due to parallel processing '
              'results will be similar but not identical. '
              'Use only one process and set --deterministic_execution to True '
              'if you want identical results (only recommended for debugging).')


class FeatureExtractor(nn.Module):
    """
    One-layer feed-forward embedding (``nn.Linear`` followed by
    ``activation_function``) for states/actions/rewards.

    With ``output_size == 0`` no linear layer is built (``self.fc`` is None)
    and ``forward`` ignores its input, returning an empty tensor on the
    module-level ``device``.
    """
    def __init__(self, input_size, output_size, activation_function):
        super().__init__()
        self.output_size = output_size
        self.activation_function = activation_function
        self.fc = None if output_size == 0 else nn.Linear(input_size, output_size)

    def forward(self, inputs):
        # Zero-width embedding: nothing to compute.
        if self.output_size == 0:
            return torch.zeros(0).to(device)
        return self.activation_function(self.fc(inputs))


def get_states_from_state_dicts(state_dicts, env_name, keep_time):
    """
    Assemble the state array from an environment's state dict.

    state_dicts: Dict or OrderedDict. Must contain the counter entry, keyed
        'trial' for CoupledBlockDF / UncoupledBlockDF / RandomWalkDF
        environments and 'timestep' otherwise. The counter must have a
        ``.shape`` whose first axis is the number of processes. 'Timed'
        environments must also provide 'go_cue'. With a single process this
        may be a scalar; otherwise it must hold one value per process.
    env_name: str, environment id; the prefix before the first '-' selects
        the counter key.
    keep_time: boolean, whether to keep the time/trial counter.

    Returns a numpy array:
      - 'Timed' env, keep_time:     counter concatenated with go_cue along axis 1
      - 'Timed' env, not keep_time: go_cue as shape (num_processes, 1)
      - other env, keep_time:       the counter entry unchanged
      - other env, not keep_time:   an empty array
    """
    if env_name.split('-')[0] in ('CoupledBlockDF', 'UncoupledBlockDF', 'RandomWalkDF'):
        state_name = 'trial'
    else:
        state_name = 'timestep'

    counter = state_dicts[state_name]
    num_processes = counter.shape[0]

    if 'Timed' in env_name:
        # With one process 'go_cue' may be a plain scalar rather than an array.
        # np.asarray accepts both, so every branch gets a (num_processes, 1)
        # column. Previously the keep_time branch crashed on scalars.
        go_cue = np.asarray(state_dicts['go_cue']).reshape(num_processes, 1)
        if keep_time:
            return np.concatenate([counter, go_cue], axis=1)
        return go_cue

    if keep_time:
        return counter
    return np.array([])


def _moving_average(values, window):
    """Return the trailing moving average of ``values`` over ``window`` points.

    ``values`` is flattened to 1-D first. If it holds fewer than ``window``
    points (including none), an empty array is returned. Otherwise
    ``np.convolve`` would either raise on empty input or swap its arguments
    when the kernel is longer than the data, and its result would not be a
    moving average.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"rolling_length must be >= 1, got {window}")
    flat = np.asarray(values).ravel()
    if flat.size < window:
        return np.array([])
    return np.convolve(flat, np.ones(window), mode="valid") / window


def plot_training_curves(
    args: argparse.Namespace,  # for hyperparameters
    out_dir: str,  # path to save the plots
    episode_returns: list,  # list of episode returns
    actor_losses: list,  # list of actor losses
    critic_losses: list,  # list of critic losses
    policy_entropies: list,  # list of policy entropies
    activity_l2_loss: list,  # list of activity l2 losses
    rolling_length: int = 10  # rolling average length
):
    """Plot smoothed training curves and save them to ``<out_dir>/training_curves.png``.

    Draws five stacked panels:
    - episode returns
    - actor loss
    - critic loss
    - policy entropy
    - activity L2 loss

    Each panel shows a ``rolling_length``-point moving average. The activity
    L2 panel is drawn only when ``'noisy' in args.exp_label``.

    Reads ``args.num_processes``, ``args.policy_num_steps_per_update`` and
    ``args.exp_label``.
    """
    # Use the last two path components as the model name. normpath drops
    # trailing separators, which would otherwise yield an empty component.
    model_name = "/".join(os.path.normpath(out_dir).split(os.sep)[-2:])

    fig, axs = plt.subplots(
        nrows=5, ncols=1,
        figsize=(12, 12),
        dpi=300
    )
    fig.suptitle(
        f"Training plots for model {model_name}\n"
        f"(n_envs={args.num_processes}, n_steps_per_update={args.policy_num_steps_per_update})"
    )

    panels = [
        ("Episode Returns", episode_returns),
        ("Actor Loss", actor_losses),
        ("Critic Loss", critic_losses),
        ("Entropy", policy_entropies),
        # Only plotted for 'noisy' experiments; the panel stays empty otherwise.
        ("Activity L2 loss", activity_l2_loss if 'noisy' in args.exp_label else None),
    ]
    for ax, (title, values) in zip(axs, panels):
        ax.set_title(title)
        if values is not None:
            ax.plot(_moving_average(values, rolling_length))
        ax.set_xlabel("Number of updates")

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'training_curves.png'))
    # pyplot keeps every figure alive until closed; repeated calls during
    # training would otherwise accumulate figures in memory.
    plt.close(fig)


def plot_evaluation_curves(
    out_dir: str,  # path to save the plots
    eval_epoch_ids: list,  # list of evaluation epoch ids
    empirical_return_avgs: list,  # list of average empirical returns
    empirical_return_stds: list,  # list of std of empirical returns
    num_eval_runs: int  # number of evaluation runs
):
    """Plot mean evaluation return per epoch and save ``<out_dir>/eval_empirical_returns.png``.

    Error bars are ``empirical_return_stds / sqrt(num_eval_runs)``. When the
    stds are per-run standard deviations, this is the standard error of the mean.
    """
    # Use the last two path components as the model name. normpath drops
    # trailing separators, which would otherwise yield an empty component.
    model_name = "/".join(os.path.normpath(out_dir).split(os.sep)[-2:])

    fig, ax = plt.subplots(
        nrows=1, ncols=1,
        figsize=(6, 4),
        dpi=300
    )
    ax.set_title(
        f"Eval empirical returns, model {model_name}"
    )
    ax.errorbar(
        eval_epoch_ids,
        empirical_return_avgs,
        # np.asarray lets a plain list (as annotated) be divided; list / float
        # raises TypeError.
        yerr=np.asarray(empirical_return_stds) / np.sqrt(num_eval_runs)
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("episodic return")

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, 'eval_empirical_returns.png'))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)



#################################################
# STATE MACHINE ANALYSIS  
#################################################
# https://github.com/openai/baselines/blob/master/baselines/common/tf_util.py#L87
def init_normc_(weight, gain=1):
    """In place: fill ``weight`` with N(0, 1) samples, then rescale so that
    every row (norm taken over dim 1) has L2 norm ``gain``."""
    weight.normal_(0, 1)
    row_norms = torch.sqrt(weight.pow(2).sum(1, keepdim=True))
    scale = gain / row_norms
    weight *= scale

def init(module, weight_init, bias_init, gain=1.0):
    """Initialize ``module`` in place and return it.

    Calls ``weight_init(module.weight.data, gain=gain)`` first, then
    ``bias_init(module.bias.data)``.
    """
    weight_init(module.weight.data, gain=gain)
    bias = module.bias.data
    bias_init(bias)
    return module

class StateSpaceMapper(nn.Module):
    """
    Feed-forward network mapping source states to target states.

    Architecture: ``nn.Linear`` hidden layers, each followed by
    ``activation_function``, then a final linear output layer with no
    activation.

    Args:
        source_dim: size of the last (feature) dimension of the inputs.
        hidden_layers: sequence of hidden-layer widths. If empty, the output
            layer is applied directly to the inputs.
        activation_function: 'tanh', 'relu' or 'leaky-relu'.
        initialization_method: 'orthogonal' or 'normc'. Weights of every layer
            (including the output layer) are scaled by the gain recommended
            for ``activation_function``; biases are set to 0.
        target_dim: size of the output feature dimension.

    Raises:
        ValueError: on an unsupported activation function or initialization
            method.
    """
    def __init__(
        self,
        # inputs
        source_dim,
        # network
        hidden_layers,
        activation_function,  # tanh, relu, leaky-relu
        initialization_method, # orthogonal, normc
        # outputs
        target_dim
    ):
        super(StateSpaceMapper, self).__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim

        # Map each accepted name to (activation module class, name understood
        # by nn.init.calculate_gain). calculate_gain spells it 'leaky_relu',
        # so passing 'leaky-relu' straight through raised ValueError.
        activations = {
            'tanh': (nn.Tanh, 'tanh'),
            'relu': (nn.ReLU, 'relu'),
            'leaky-relu': (nn.LeakyReLU, 'leaky_relu'),
        }
        if activation_function not in activations:
            raise ValueError(
                f"Unsupported activation_function {activation_function!r}; "
                f"expected one of {sorted(activations)}"
            )
        activation_cls, gain_name = activations[activation_function]
        self.activation_function = activation_cls()
        # nn.LeakyReLU() and calculate_gain('leaky_relu') both default to a
        # negative slope of 0.01, so the gain matches the activation used.
        gain = nn.init.calculate_gain(gain_name)

        # Choose the weight initializer; fail fast instead of leaving init_ unset.
        if initialization_method == 'normc':
            weight_init = init_normc_
        elif initialization_method == 'orthogonal':
            weight_init = nn.init.orthogonal_
        else:
            raise ValueError(
                f"Unsupported initialization_method {initialization_method!r}; "
                "expected 'normc' or 'orthogonal'"
            )
        self.init_ = lambda m: init(
            m, weight_init,
            lambda x: nn.init.constant_(x, 0),
            gain
        )

        self.mapper_layers, mapper_fc_final_dim = self.gen_fc_layers(
            hidden_layers, self.source_dim)

        # output layer (no activation)
        self.mapper_output = self.init_(nn.Linear(
            mapper_fc_final_dim, self.target_dim))

    def gen_fc_layers(self, layers, curr_input_dim):
        """Build chained ``nn.Linear`` layers of widths ``layers`` starting
        from ``curr_input_dim``, each initialized with ``self.init_``.

        Returns:
            (nn.ModuleList of the layers, output width of the last layer, or
            ``curr_input_dim`` itself when ``layers`` is empty).
        """
        fc_layers = nn.ModuleList([])
        for width in layers:
            fc_layers.append(self.init_(nn.Linear(curr_input_dim, width)))
            curr_input_dim = width
        return fc_layers, curr_input_dim

    def forward_mapper(self, inputs):
        """Apply each hidden layer followed by the activation function."""
        h = inputs
        for layer in self.mapper_layers:
            h = self.activation_function(layer(h))
        return h

    def forward(self, source_states):
        """Map ``source_states`` to the target space.

        ``source_states`` must have ``source_dim`` as its last dimension
        (original note: sequence_len x batch_size x feature_dim). Leading
        dimensions are carried through by ``nn.Linear``.
        """
        mapper_h = self.forward_mapper(source_states)
        mapped_states = self.mapper_output(mapper_h)

        return mapped_states


class StateMapperTrainer:
    """
    Supervised trainer that fits a state mapper to (source, target) pairs
    with an MSE loss and the Adam optimizer.

    Args:
        state_mapper: nn.Module mapping a float tensor of sources to
            predictions comparable with the targets; it is moved to ``device``.
        lr: Adam learning rate.
        eps: Adam epsilon.
        anneal_lr: if True, ``self.lr_scheduler`` is a LambdaLR scaling the
            learning rate by ``1 - f / train_steps`` after ``f`` scheduler
            steps; otherwise ``self.lr_scheduler`` is None.
        train_steps: number of scheduler steps over which the factor reaches 0
            (only used when ``anneal_lr``).
    """
    def __init__(
        self,
        state_mapper,
        # optimization
        lr,
        eps,
        anneal_lr,
        train_steps
    ):

        self.state_mapper = state_mapper.to(device)

        # initialize optimizer
        self.criterion = nn.MSELoss().to(device)
        self.optimizer = torch.optim.Adam(
            self.state_mapper.parameters(),
            lr=lr,
            eps=eps
        )
        # learning rate annealing. Define lr_scheduler in both cases; before,
        # it only existed when anneal_lr was set.
        self.lr_scheduler = None
        # Kept for backward compatibility; this attribute never held a scheduler.
        self.lr_scheduler_policy = None
        if anneal_lr:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lambda f: 1 - f / train_steps)
        # NOTE(review): train() never steps the scheduler, so callers must call
        # lr_scheduler.step() themselves for annealing to happen. The factor
        # becomes negative once f > train_steps; confirm the stepping schedule.

    def train(
        self,
        dataset,
        num_epochs,
        batch_size,
        data_sampler=None
    ):
        """
        Train the mapper and return the per-epoch mean squared error.

        Args:
            dataset: map-style dataset yielding (source, target) pairs.
            num_epochs: number of passes over the data loader.
            batch_size: mini-batch size.
            data_sampler: optional torch Sampler. When given, it is used
                instead of shuffling.

        Returns:
            np.ndarray with one entry per epoch: the MSE averaged over all
            samples seen in that epoch (nan if no samples were yielded).
        """
        self.state_mapper.train()

        # A DataLoader can be iterated repeatedly, and shuffling/sampling is
        # redrawn on every pass, so build it once rather than per epoch.
        if data_sampler is not None:
            train_loader = DataLoader(
                dataset,
                sampler=data_sampler,
                batch_size=batch_size
            )
        else:
            train_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True
            )

        losses_mse = []
        for epoch in range(num_epochs):
            loss_sum = 0.0
            num_samples = 0

            for x_, t_ in train_loader:
                x = x_.to(device=device, dtype=torch.float32)  # (batch_size, source_dim)
                t = t_.to(device=device, dtype=torch.float32)  # (batch_size, target_dim)

                # -- training --
                self.optimizer.zero_grad()
                y = self.state_mapper(x)
                loss_mse_batch = self.criterion(y, t)
                loss_mse_batch.backward()
                self.optimizer.step()

                # MSELoss averages over the batch, so weight each batch by its
                # size to get a true per-sample mean for the epoch. Dividing
                # summed batch means by len(dataset) shrank it by ~batch_size.
                batch_n = x.shape[0]
                loss_sum += loss_mse_batch.item() * batch_n
                num_samples += batch_n

            loss_mse_epoch = loss_sum / num_samples if num_samples else float('nan')
            losses_mse.append(loss_mse_epoch)

            # verbose
            if epoch % 100 == 0:
                print(f'epoch: {epoch}')
                print(f' loss: {loss_mse_epoch}')

        return np.array(losses_mse)
